The third option for classifying the 30 species of cropped flowers is transfer learning with a pre-trained model. Below, we use the MobileNet V2 classifier (https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/3) and its feature-vector variant, which exposes all layers except the final classification layer (https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/3).
import numpy as np
import cv2
import PIL.Image as Image
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pylab as pl
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.utils import to_categorical
import pathlib
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam,SGD,Adagrad,Adadelta,RMSprop
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback, ReduceLROnPlateau
import itertools
from sklearn.metrics import confusion_matrix
x = []  # Feature dataset with images
y = []  # Target dataset with labels (folder name = species)
folder_dir = './Flower-Data_CNN'
size = (224, 224)  # Resize every image to 224x224, the input size MobileNet V2 expects

# Each sub-folder holds the images of one flower species; the folder name is the label.
for folder in os.listdir(folder_dir):
    for file in os.listdir(os.path.join(folder_dir, folder)):
        # FIX: the original `endswith("JPG")` was case-sensitive and silently
        # skipped .jpg/.jpeg files; accept any case.
        if not file.lower().endswith(("jpg", "jpeg")):
            continue
        img = cv2.imread(os.path.join(folder_dir, folder, file))
        if img is None:
            # FIX: unreadable/corrupt file — skip it instead of crashing in
            # cvtColor. The label is appended only after a successful read,
            # so x and y can no longer fall out of sync.
            continue
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model wants RGB
        x.append(cv2.resize(img_rgb, size))
        y.append(folder)
# Wrap the complete pre-trained MobileNet V2 classifier (ImageNet head included)
# in a Keras model so it can be called like any other layer stack.
classifier = tf.keras.Sequential([
    hub.KerasLayer(
        "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/3",
        input_shape=size + (3,),
    ),
])

# Hold out 20% of the samples for testing; random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.20, random_state=10)

# Inspect the first training image (raw uint8 pixel values).
X_train[0]
array([[[137, 157, 152],
[182, 198, 197],
[185, 199, 199],
...,
[ 9, 14, 11],
[ 19, 28, 27],
[ 22, 34, 32]],
[[185, 204, 198],
[160, 175, 172],
[172, 190, 188],
...,
[ 15, 21, 17],
[ 20, 30, 29],
[ 20, 32, 30]],
[[173, 191, 185],
[172, 195, 189],
[165, 190, 186],
...,
[ 20, 29, 24],
[ 17, 29, 27],
[ 19, 32, 30]],
...,
[[ 70, 107, 136],
[ 66, 103, 133],
[ 62, 100, 128],
...,
[ 57, 95, 98],
[ 72, 105, 110],
[ 98, 129, 139]],
[[ 38, 62, 83],
[ 36, 57, 77],
[ 34, 54, 72],
...,
[ 55, 90, 94],
[ 74, 105, 107],
[ 85, 114, 118]],
[[ 13, 27, 30],
[ 12, 25, 25],
[ 16, 26, 25],
...,
[ 67, 98, 101],
[ 71, 100, 102],
[ 62, 90, 93]]], dtype=uint8)
# Display the first training image; imshow returns an AxesImage handle.
img = plt.imshow(X_train[0])
# The matching label (species folder name) for that image.
y_train[0]
'Southern Marsh Orchid'
# Encode the string labels as integers 0..29.
# FIX: the original fitted a SEPARATE LabelEncoder on y_train and on y_test.
# LabelEncoder numbers classes by the sorted order of the labels it sees, so
# two independent fits only agree if both splits happen to contain every
# class. Fit a single encoder on the full label list and use it for both
# splits so train and test share one consistent mapping.
label_encoder = LabelEncoder().fit(y)
y_train = label_encoder.transform(y_train)
y_test = label_encoder.transform(y_test)
y_train
array([24, 3, 1, ..., 5, 20, 5], dtype=int64)
# Scale pixel intensities from [0, 255] down to [0.0, 1.0] (float64),
# the input range the TF-Hub MobileNet V2 module expects.
X_train = np.asarray(X_train) / 255
X_test = np.asarray(X_test) / 255
# Peek at the scaled training tensor.
X_train
array([[[[0.5372549 , 0.61568627, 0.59607843],
[0.71372549, 0.77647059, 0.77254902],
[0.7254902 , 0.78039216, 0.78039216],
...,
[0.03529412, 0.05490196, 0.04313725],
[0.0745098 , 0.10980392, 0.10588235],
[0.08627451, 0.13333333, 0.1254902 ]],
[[0.7254902 , 0.8 , 0.77647059],
[0.62745098, 0.68627451, 0.6745098 ],
[0.6745098 , 0.74509804, 0.7372549 ],
...,
[0.05882353, 0.08235294, 0.06666667],
[0.07843137, 0.11764706, 0.11372549],
[0.07843137, 0.1254902 , 0.11764706]],
[[0.67843137, 0.74901961, 0.7254902 ],
[0.6745098 , 0.76470588, 0.74117647],
[0.64705882, 0.74509804, 0.72941176],
...,
[0.07843137, 0.11372549, 0.09411765],
[0.06666667, 0.11372549, 0.10588235],
[0.0745098 , 0.1254902 , 0.11764706]],
...,
[[0.2745098 , 0.41960784, 0.53333333],
[0.25882353, 0.40392157, 0.52156863],
[0.24313725, 0.39215686, 0.50196078],
...,
[0.22352941, 0.37254902, 0.38431373],
[0.28235294, 0.41176471, 0.43137255],
[0.38431373, 0.50588235, 0.54509804]],
[[0.14901961, 0.24313725, 0.3254902 ],
[0.14117647, 0.22352941, 0.30196078],
[0.13333333, 0.21176471, 0.28235294],
...,
[0.21568627, 0.35294118, 0.36862745],
[0.29019608, 0.41176471, 0.41960784],
[0.33333333, 0.44705882, 0.4627451 ]],
[[0.05098039, 0.10588235, 0.11764706],
[0.04705882, 0.09803922, 0.09803922],
[0.0627451 , 0.10196078, 0.09803922],
...,
[0.2627451 , 0.38431373, 0.39607843],
[0.27843137, 0.39215686, 0.4 ],
[0.24313725, 0.35294118, 0.36470588]]],
[[[0.38823529, 0.45882353, 0.12941176],
[0.37647059, 0.44705882, 0.11764706],
[0.37254902, 0.44313725, 0.11764706],
...,
[0.51372549, 0.56470588, 0.29411765],
[0.56078431, 0.61176471, 0.34117647],
[0.57647059, 0.62745098, 0.35686275]],
[[0.38823529, 0.45882353, 0.12941176],
[0.38039216, 0.45098039, 0.12156863],
[0.37647059, 0.44705882, 0.12156863],
...,
[0.54117647, 0.59215686, 0.31764706],
[0.55294118, 0.60392157, 0.33333333],
[0.5254902 , 0.57647059, 0.30588235]],
[[0.38039216, 0.45098039, 0.12156863],
[0.37647059, 0.44705882, 0.11764706],
[0.37647059, 0.44705882, 0.12156863],
...,
[0.55294118, 0.60392157, 0.33333333],
[0.5254902 , 0.57254902, 0.30980392],
[0.46666667, 0.51764706, 0.25490196]],
...,
[[0.33333333, 0.36862745, 0.2 ],
[0.3254902 , 0.35686275, 0.19215686],
[0.31372549, 0.3372549 , 0.18823529],
...,
[0.29803922, 0.35686275, 0.14509804],
[0.29019608, 0.34901961, 0.1372549 ],
[0.28235294, 0.34117647, 0.13333333]],
[[0.31764706, 0.34901961, 0.19215686],
[0.30588235, 0.3372549 , 0.18431373],
[0.29411765, 0.32156863, 0.18039216],
...,
[0.28627451, 0.35294118, 0.14117647],
[0.27843137, 0.34509804, 0.13333333],
[0.2745098 , 0.34117647, 0.1254902 ]],
[[0.30980392, 0.34117647, 0.18823529],
[0.29411765, 0.3254902 , 0.17647059],
[0.28235294, 0.30980392, 0.17254902],
...,
[0.2745098 , 0.34117647, 0.12941176],
[0.27058824, 0.3372549 , 0.1254902 ],
[0.26666667, 0.33333333, 0.12156863]]],
[[[0.38431373, 0.52156863, 0.2627451 ],
[0.37254902, 0.50980392, 0.25882353],
[0.35686275, 0.49411765, 0.24313725],
...,
[0.64313725, 0.74509804, 0.56078431],
[0.63529412, 0.74509804, 0.61960784],
[0.62745098, 0.74117647, 0.63137255]],
[[0.41960784, 0.55686275, 0.29803922],
[0.40392157, 0.54117647, 0.28235294],
[0.38431373, 0.52156863, 0.2627451 ],
...,
[0.64313725, 0.74509804, 0.56862745],
[0.64313725, 0.74509804, 0.63137255],
[0.63529412, 0.73333333, 0.64313725]],
[[0.43921569, 0.57647059, 0.30980392],
[0.43137255, 0.56862745, 0.30196078],
[0.41960784, 0.55686275, 0.29019608],
...,
[0.64313725, 0.74901961, 0.58823529],
[0.63137255, 0.73333333, 0.63529412],
[0.62352941, 0.72156863, 0.63921569]],
...,
[[0.35294118, 0.43137255, 0.29411765],
[0.4 , 0.48627451, 0.3372549 ],
[0.42745098, 0.51372549, 0.36078431],
...,
[0.36078431, 0.50196078, 0.2627451 ],
[0.35686275, 0.50588235, 0.26666667],
[0.36470588, 0.51372549, 0.2745098 ]],
[[0.35686275, 0.45098039, 0.30196078],
[0.4 , 0.49803922, 0.34117647],
[0.42745098, 0.5254902 , 0.36862745],
...,
[0.4 , 0.50588235, 0.30980392],
[0.36862745, 0.49803922, 0.2745098 ],
[0.36862745, 0.50196078, 0.26666667]],
[[0.37647059, 0.4745098 , 0.32156863],
[0.41568627, 0.51372549, 0.35686275],
[0.42745098, 0.5254902 , 0.36862745],
...,
[0.38823529, 0.49019608, 0.30588235],
[0.38823529, 0.50588235, 0.28235294],
[0.38823529, 0.50980392, 0.27843137]]],
...,
[[[0.61176471, 0.67843137, 0.4745098 ],
[0.70588235, 0.72156863, 0.58431373],
[0.61568627, 0.64705882, 0.47843137],
...,
[0.81960784, 0.14901961, 0.10980392],
[0.82352941, 0.16862745, 0.11764706],
[0.83137255, 0.19215686, 0.1254902 ]],
[[0.62745098, 0.69803922, 0.49019608],
[0.7254902 , 0.75294118, 0.59215686],
[0.58039216, 0.63921569, 0.42745098],
...,
[0.79215686, 0.12156863, 0.07843137],
[0.82352941, 0.16078431, 0.10588235],
[0.83529412, 0.18823529, 0.12156863]],
[[0.65098039, 0.70196078, 0.50196078],
[0.70196078, 0.7372549 , 0.56078431],
[0.50196078, 0.59215686, 0.33333333],
...,
[0.79215686, 0.09411765, 0.05490196],
[0.81176471, 0.13333333, 0.08235294],
[0.82352941, 0.16862745, 0.10588235]],
...,
[[0.50980392, 0.58039216, 0.38039216],
[0.49803922, 0.58823529, 0.3372549 ],
[0.55686275, 0.65490196, 0.40784314],
...,
[0.83921569, 0.03921569, 0.04313725],
[0.83529412, 0.03529412, 0.03137255],
[0.83529412, 0.03921569, 0.02352941]],
[[0.45098039, 0.5254902 , 0.32156863],
[0.49803922, 0.58823529, 0.34117647],
[0.50980392, 0.60392157, 0.34901961],
...,
[0.83137255, 0.03137255, 0.03529412],
[0.83529412, 0.03529412, 0.03529412],
[0.83137255, 0.03137255, 0.02745098]],
[[0.45882353, 0.53333333, 0.32941176],
[0.48627451, 0.57254902, 0.3372549 ],
[0.52156863, 0.61176471, 0.34117647],
...,
[0.83529412, 0.03529412, 0.04705882],
[0.82745098, 0.02745098, 0.03529412],
[0.83137255, 0.02745098, 0.02745098]]],
[[[0.49803922, 0.61960784, 0.25098039],
[0.49803922, 0.63137255, 0.24313725],
[0.48627451, 0.62745098, 0.23137255],
...,
[0.28627451, 0.32941176, 0.14509804],
[0.28627451, 0.36470588, 0.11372549],
[0.35294118, 0.45490196, 0.16470588]],
[[0.49803922, 0.62352941, 0.24313725],
[0.49803922, 0.63137255, 0.25098039],
[0.48627451, 0.62352941, 0.24705882],
...,
[0.29803922, 0.34509804, 0.14117647],
[0.3254902 , 0.4 , 0.14117647],
[0.37647059, 0.48235294, 0.17254902]],
[[0.50980392, 0.63529412, 0.25490196],
[0.50196078, 0.63137255, 0.25098039],
[0.49803922, 0.63137255, 0.25882353],
...,
[0.29411765, 0.34901961, 0.1372549 ],
[0.36078431, 0.44705882, 0.18039216],
[0.38823529, 0.49411765, 0.18431373]],
...,
[[0.40784314, 0.48235294, 0.20784314],
[0.45882353, 0.5372549 , 0.28235294],
[0.67058824, 0.74901961, 0.51764706],
...,
[0.27058824, 0.39607843, 0.05098039],
[0.25882353, 0.38039216, 0.05882353],
[0.24705882, 0.35686275, 0.06666667]],
[[0.69411765, 0.77254902, 0.55294118],
[0.7254902 , 0.80392157, 0.59215686],
[0.69411765, 0.78039216, 0.57647059],
...,
[0.28627451, 0.38431373, 0.05882353],
[0.2745098 , 0.36470588, 0.06666667],
[0.24705882, 0.33333333, 0.0627451 ]],
[[0.70980392, 0.78431373, 0.59607843],
[0.69411765, 0.77647059, 0.59215686],
[0.69803922, 0.78431373, 0.59215686],
...,
[0.25882353, 0.31764706, 0.07843137],
[0.21568627, 0.26666667, 0.05882353],
[0.19215686, 0.23921569, 0.0627451 ]]],
[[[0.59607843, 0.69803922, 0.44313725],
[0.50588235, 0.61960784, 0.36862745],
[0.49803922, 0.61960784, 0.37647059],
...,
[0.23137255, 0.26666667, 0.15686275],
[0.68235294, 0.74509804, 0.58039216],
[0.20392157, 0.29803922, 0.0745098 ]],
[[0.49019608, 0.63137255, 0.34901961],
[0.48235294, 0.62352941, 0.35294118],
[0.4745098 , 0.61568627, 0.35686275],
...,
[0.22745098, 0.26666667, 0.15686275],
[0.65098039, 0.71372549, 0.55294118],
[0.18823529, 0.27843137, 0.07058824]],
[[0.48235294, 0.63529412, 0.36470588],
[0.48627451, 0.62745098, 0.37254902],
[0.50980392, 0.63137255, 0.39607843],
...,
[0.25882353, 0.29411765, 0.18431373],
[0.6627451 , 0.72156863, 0.56078431],
[0.21960784, 0.29803922, 0.10196078]],
...,
[[0.62352941, 0.69019608, 0.41568627],
[0.39215686, 0.4745098 , 0.21176471],
[0.35686275, 0.42352941, 0.21176471],
...,
[0.21960784, 0.22745098, 0.09803922],
[0.20392157, 0.21960784, 0.09411765],
[0.21960784, 0.24705882, 0.11372549]],
[[0.7372549 , 0.81176471, 0.55294118],
[0.47058824, 0.55686275, 0.30196078],
[0.42352941, 0.50196078, 0.27843137],
...,
[0.65882353, 0.68235294, 0.5254902 ],
[0.7254902 , 0.74901961, 0.59607843],
[0.50588235, 0.5372549 , 0.38431373]],
[[0.71764706, 0.78431373, 0.59215686],
[0.54901961, 0.63921569, 0.4 ],
[0.4627451 , 0.56078431, 0.29803922],
...,
[0.56470588, 0.59607843, 0.43921569],
[0.56862745, 0.59607843, 0.46666667],
[0.56470588, 0.60392157, 0.47058824]]]])
# The feature-vector variant of MobileNet V2 exposes every layer except the
# final ImageNet classification head, yielding one 1280-dim embedding per image.
feature_extractor_model = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/3"

# trainable=False freezes the pre-trained weights: only the layers we add on
# top will learn during training.
pretrained_model_without_top_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(224, 224, 3),
    trainable=False,
)

num_of_flowers = 30  # one output unit per flower species

# Stack a fresh Dense head (raw logits, no activation) on the frozen extractor.
model = tf.keras.Sequential([
    pretrained_model_without_top_layer,
    tf.keras.layers.Dense(num_of_flowers),
])
model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= keras_layer_1 (KerasLayer) (None, 1280) 2257984 _________________________________________________________________ dense (Dense) (None, 30) 38430 ================================================================= Total params: 2,296,414 Trainable params: 38,430 Non-trainable params: 2,257,984 _________________________________________________________________
# Save the model weights whenever validation loss improves.
checkpoint = ModelCheckpoint(
    'model.h5',
    monitor = 'val_loss',
    verbose = 1,
    save_best_only = True)

# Multiply the learning rate by 0.2 after 5 epochs without val_loss improvement.
# FIX: min_lr was 0.001, which equals Adam's default INITIAL learning rate, so
# the callback could never actually reduce anything; use a much smaller floor.
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.2,
    verbose = 1,
    patience = 5,
    min_lr = 1e-6)
# Train the model. SparseCategoricalCrossentropy(from_logits=True) matches the
# integer-encoded labels and the activation-free Dense head; 20% of the
# training data is carved off as a validation set.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
history = model.fit(
    X_train,
    y_train,
    epochs=10,
    validation_split=0.2,
    verbose=1,
    callbacks=[reduce_lr, checkpoint],
)
Epoch 1/10 134/134 [==============================] - 66s 464ms/step - loss: 1.0840 - accuracy: 0.7687 - val_loss: 0.3940 - val_accuracy: 0.9355 Epoch 00001: val_loss improved from inf to 0.39403, saving model to model.h5 Epoch 2/10 134/134 [==============================] - 56s 420ms/step - loss: 0.2491 - accuracy: 0.9652 - val_loss: 0.2267 - val_accuracy: 0.9636 Epoch 00002: val_loss improved from 0.39403 to 0.22670, saving model to model.h5 Epoch 3/10 134/134 [==============================] - 56s 417ms/step - loss: 0.1450 - accuracy: 0.9818 - val_loss: 0.1767 - val_accuracy: 0.9692 Epoch 00003: val_loss improved from 0.22670 to 0.17670, saving model to model.h5 Epoch 4/10 134/134 [==============================] - 57s 428ms/step - loss: 0.0988 - accuracy: 0.9920 - val_loss: 0.1393 - val_accuracy: 0.9729 Epoch 00004: val_loss improved from 0.17670 to 0.13933, saving model to model.h5 Epoch 5/10 134/134 [==============================] - 61s 454ms/step - loss: 0.0708 - accuracy: 0.9946 - val_loss: 0.1252 - val_accuracy: 0.9757 Epoch 00005: val_loss improved from 0.13933 to 0.12519, saving model to model.h5 Epoch 6/10 134/134 [==============================] - 57s 429ms/step - loss: 0.0551 - accuracy: 0.9970 - val_loss: 0.1098 - val_accuracy: 0.9794 Epoch 00006: val_loss improved from 0.12519 to 0.10976, saving model to model.h5 Epoch 7/10 134/134 [==============================] - 56s 421ms/step - loss: 0.0428 - accuracy: 0.9986 - val_loss: 0.1011 - val_accuracy: 0.9804 Epoch 00007: val_loss improved from 0.10976 to 0.10108, saving model to model.h5 Epoch 8/10 134/134 [==============================] - 56s 421ms/step - loss: 0.0353 - accuracy: 0.9984 - val_loss: 0.0971 - val_accuracy: 0.9785 Epoch 00008: val_loss improved from 0.10108 to 0.09709, saving model to model.h5 Epoch 9/10 134/134 [==============================] - 57s 427ms/step - loss: 0.0290 - accuracy: 1.0000 - val_loss: 0.0937 - val_accuracy: 0.9804 Epoch 00009: val_loss improved from 0.09709 to 
0.09371, saving model to model.h5 Epoch 10/10 134/134 [==============================] - 56s 421ms/step - loss: 0.0241 - accuracy: 0.9998 - val_loss: 0.0885 - val_accuracy: 0.9813 Epoch 00010: val_loss improved from 0.09371 to 0.08850, saving model to model.h5
# Metrics recorded per epoch ('lr' appears because ReduceLROnPlateau is used).
print(history.history.keys())
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy', 'lr'])
def _plot_history_curve(history, metric, title, ylabel):
    """Plot the train and validation curves for one recorded metric."""
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend(['Train','Val'],loc='upper left')
    plt.show()

#accuracy
_plot_history_curve(history, 'accuracy', 'Model Accuracy', 'Accuracy')
#Loss
_plot_history_curve(history, 'loss', 'Model Loss', 'Loss')
# Prediction
# NOTE(review): predictions are made on the TRAINING set here, so the confusion
# matrix below measures training fit, not generalisation — consider X_test.
# np.array() is redundant: X_train is already a NumPy array after scaling.
predict_model = model.predict(np.array(X_train))
# Collapse the per-class logits to a single predicted class index per image.
predict_model = np.argmax(predict_model, axis=1)
predict_model
array([24, 3, 1, ..., 5, 20, 5], dtype=int64)
# 30x30 count matrix: rows are true species, columns are predicted species.
cm = confusion_matrix(y_true = y_train, y_pred = predict_model)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Render the confusion matrix *cm* as an annotated heatmap.

    Parameters
    ----------
    cm : 2-D array of counts, as returned by sklearn.metrics.confusion_matrix.
    classes : sequence of tick labels, one per class, in cm's row order.
    normalize : if True, convert each row to fractions before plotting.
    title : figure title.
    cmap : matplotlib colormap for the heatmap.
    """
    # FIX: normalize BEFORE imshow. The original normalized afterwards, so with
    # normalize=True the heatmap colours still reflected the raw counts while
    # the printed/annotated values were fractions.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Annotate every cell; dark cells get white text for readability.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# Human-readable species names in label-encoder (alphabetical) order, one per
# row/column of the 30x30 confusion matrix.
cm_plot_labels = [
    'Bird\'s-foot Trefoil', 'Brown Knapweed', 'Buttercup', 'Chamomile',
    'Common Dandelion', 'Common Poppy', 'Cornflower', 'Cow Parsley',
    'Cuckooflower', 'Field Mouse-ear', 'Flatweed', 'Hares-foot clover',
    'Health Spotted Orchid', 'Hoary Alyssum', 'Lesser Spearwort',
    'Marsh Lousewort', 'Marsh marigold', 'Meadow Thistle', 'Ox-eye Daisy',
    'Perforate St John\'s-wort', 'Purple Loosestrife', 'Ragwort', 'Red Clover',
    'Redstem Filaree', 'Southern Marsh Orchid', 'Tansy', 'White Clover',
    'Wild Carrot', 'Yarrow', 'Yellow Loosestrife',
]

# A large canvas keeps all 30x30 cells and their tick labels legible.
fig, ax = plt.subplots(figsize=(12, 12))
plot_confusion_matrix(cm=cm, classes=cm_plot_labels, title='Confusion Matrix')
Confusion matrix, without normalization
[[177 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 190 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 1 0 194 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 159 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 1 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 176 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 187 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 194 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 187 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 165 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 182 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 1 0 0 0 0 0 185 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 159 0 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 1 0 0 0 174 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 192 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 198 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 177 0 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 166 0
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 176
0 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
183 0 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 178 0 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 183 0 0 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 174 0 0 0 0 0 0 0 0]
[ 0 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
0 0 0 0 124 0 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 174 0 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 177 0 0 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 187 0 0 0 0]
[ 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 172 0 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 186 0 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 174 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 176]]
# One-hot is only used to recover the class index below (argmax inverts it).
y_test_one_hot = tf.one_hot(y_test, depth=30)
categories = np.sort(os.listdir(folder_dir))  # alphabetical == LabelEncoder order
size = 224

# FIX: predict ONCE for the whole test set. The original called
# model.predict(X_test) inside the loop (up to 3 times per subplot), re-running
# inference over the entire test set ~100 times for a 6x6 grid.
X_test_arr = np.array(X_test)
test_predictions = np.argmax(model.predict(X_test_arr), axis=1)

fig, ax = plt.subplots(6, 6, figsize=(25, 40))
for i in range(6):
    for j in range(6):
        # Pick a random test image for this cell.
        k = int(np.random.random_sample() * len(X_test))
        true_name = categories[np.argmax(y_test_one_hot[k])]
        pred_name = categories[test_predictions[k]]
        # Green annotations mark a correct prediction, red a mistake
        # (the original duplicated both branches verbatim except for colour).
        colour = 'green' if true_name == pred_name else 'red'
        ax[i, j].set_title("TRUE: " + true_name, color=colour)
        ax[i, j].set_xlabel("PREDICTED: " + pred_name, color=colour)
        ax[i, j].imshow(X_test_arr[k].reshape(size, size, 3), cmap='gray')